{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac1187c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys  \n",
    "import json\n",
    "import torch \n",
    "import argparse\n",
    "import numpy as np\n",
    "from PIL import Image  \n",
    "from tqdm import tqdm\n",
    "from utils.utils import model_gen, load_jsonl\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer  \n",
    " \n",
    "# Hugging Face checkpoint to evaluate. trust_remote_code=True lets transformers\n",
    "# run the custom modelling/tokenizer code shipped with the checkpoint.\n",
    "ckpt_path = 'internlm/internlm-xcomposer2-vl-7b'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)\n",
    "\n",
    "# Load the weights onto the GPU, then switch to eval mode and half precision.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    ckpt_path,\n",
    "    device_map=\"cuda\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "model = model.eval().cuda().half()\n",
    "\n",
    "# Attach the tokenizer to the model object.\n",
    "# NOTE(review): presumably read by model_gen / the checkpoint's own code -- confirm.\n",
    "model.tokenizer = tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a607d317",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, concatenate_datasets   \n",
    "from utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG\n",
    "\n",
    "# Load the validation split once per value of CAT_SHORT2LONG (used as the\n",
    "# dataset config name) and merge all of them into a single dataset.\n",
    "sub_dataset_list = [\n",
    "    load_dataset('MMMU/MMMU', subject, split='validation')\n",
    "    for subject in CAT_SHORT2LONG.values()\n",
    "]\n",
    "dataset = concatenate_datasets(sub_dataset_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44825fc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "\n",
    "# Directory of the MMMU validation images, one `<idx>.png` per sample.\n",
    "IMAGE_DIR = '/mnt/petrelfs/share_data/dongxiaoyi/share_data/online_eval/mmmu'\n",
    "\n",
    "# Matches a single uppercase letter; for multiple-choice questions the first\n",
    "# match in the model response is recorded as the chosen option.\n",
    "pattern = re.compile(r'[A-Z]')\n",
    "\n",
    "# Per-question entries (prompt text and answer type), looked up by sample id.\n",
    "with open('mmmu_val.json') as f:\n",
    "    processd_samples = json.load(f)\n",
    "\n",
    "results = {}\n",
    "# Iterating the dataset yields one sample dict per row, so the running index\n",
    "# has to come from enumerate() rather than from tuple-unpacking the sample.\n",
    "# NOTE(review): assumes each image file is named after the sample's position\n",
    "# in `dataset` -- confirm against how the images were exported.\n",
    "for idx, sample in enumerate(tqdm(dataset)):\n",
    "    qid = sample['id']\n",
    "    text = processd_samples[qid]['text']\n",
    "    answer_type = processd_samples[qid]['answer_type']\n",
    "    # Fail before spending GPU time instead of leaving `res` unbound or\n",
    "    # silently reusing the previous sample's answer.\n",
    "    if answer_type not in ('choice', 'open'):\n",
    "        raise ValueError(f'Unknown answer_type {answer_type!r} for question {qid}')\n",
    "\n",
    "    image = f'{IMAGE_DIR}/{idx}.png'\n",
    "    with torch.cuda.amp.autocast():\n",
    "        response = model_gen(model, text, image)\n",
    "\n",
    "    if answer_type == 'choice':\n",
    "        res = pattern.findall(response)\n",
    "        if len(res) == 0:\n",
    "            # No option letter in the response: log it and store the sentinel 'E'.\n",
    "            print('Error:', response)\n",
    "            res = 'E'\n",
    "        else:\n",
    "            res = res[0]\n",
    "    else:  # answer_type == 'open'\n",
    "        res = response.lower()\n",
    "    results[qid] = res\n",
    "\n",
    "# Write through a context manager so the file is flushed and closed.\n",
    "with open('MMMU_val_InternLM_XComposer_VL.json', 'w') as f:\n",
    "    json.dump(results, f, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0f721aa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
